# Style Transfer
# Table of Contents
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.applications.vgg16 import VGG16
import cv2
# Working resolution shared by the content, style, and generated images.
h_image, w_image = 600, 1000

def _load_rgb(path):
    # OpenCV decodes to BGR; convert to RGB and resize to the working resolution.
    bgr = cv2.imread(path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return cv2.resize(rgb, (w_image, h_image))

def _show(img):
    # Display an RGB image full-frame, without axes.
    plt.figure(figsize = (10,8))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

img_content = _load_rgb('./image_files/postech_flag.jpg')
_show(img_content)

img_style = _load_rgb('./image_files/la_muse.jpg')
_show(img_style)
# Load VGG16 pretrained on ImageNet; only its convolutional parameters are
# reused below (the network itself is rebuilt manually in net()).
model = VGG16(weights = 'imagenet')
model.summary()
vgg16_weights = model.get_weights()

# The 13 conv layers of VGG16 in graph order.  In model.get_weights() the
# kernel of the i-th conv layer sits at index 2*i and its bias at 2*i + 1
# (fully-connected parameters follow at index 26+ and are unused here).
# Building the dicts from this list replaces 26 hand-typed magic indices.
_conv_names = ['conv1_1', 'conv1_2',
               'conv2_1', 'conv2_2',
               'conv3_1', 'conv3_2', 'conv3_3',
               'conv4_1', 'conv4_2', 'conv4_3',
               'conv5_1', 'conv5_2', 'conv5_3']

# kernel size: [kernel_height, kernel_width, input_ch, output_ch]
weights = {name: tf.constant(vgg16_weights[2*i])
           for i, name in enumerate(_conv_names)}
# bias size: [output_ch]
biases = {name: tf.constant(vgg16_weights[2*i + 1])
          for i, name in enumerate(_conv_names)}

# input layer: [1, image_height, image_width, channels]
input_content = tf.placeholder(tf.float32, [1, h_image, w_image, 3])
input_style = tf.placeholder(tf.float32, [1, h_image, w_image, 3])
def net(x, weights, biases):
    """Run x through the VGG16 convolutional stack and collect activations.

    Parameters
    ----------
    x : tensor, shape [1, h, w, 3]
        Input image batch.
    weights, biases : dict
        Conv kernels/biases keyed by layer name ('conv1_1' ... 'conv5_3').

    Returns
    -------
    dict mapping each conv-layer name to its post-ReLU activation tensor.

    The original hand-unrolled 13 conv + 5 pool layers are generated from an
    architecture table instead; the resulting graph (conv -> bias add -> relu
    per layer, 2x2 stride-2 max-pool after each group) is identical.
    """
    architecture = [
        ['conv1_1', 'conv1_2'],
        ['conv2_1', 'conv2_2'],
        ['conv3_1', 'conv3_2', 'conv3_3'],
        ['conv4_1', 'conv4_2', 'conv4_3'],
        ['conv5_1', 'conv5_2', 'conv5_3'],
    ]
    features = {}
    out = x
    for group in architecture:
        for name in group:
            out = tf.nn.conv2d(out,
                               weights[name],
                               strides = [1, 1, 1, 1],
                               padding = 'SAME')
            out = tf.nn.relu(tf.add(out, biases[name]))
            features[name] = out
        # Halve spatial resolution after each conv group (as in VGG16).
        out = tf.nn.max_pool(out,
                             ksize = [1, 2, 2, 1],
                             strides = [1, 2, 2, 1],
                             padding = 'VALID')
    return features
# Layers whose Gram matrices define the style representation (one per VGG group).
layers_style = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
# Layer whose raw activations define the content representation.
layers_content = ['conv4_2']
# Adam learning rate; large because the optimization variable is raw pixels in [0, 255].
LR = 30
# composite image is the only variable that needs to be updated
input_gen = tf.Variable(tf.random_uniform([1, h_image, w_image, 3], maxval = 255))
def get_gram_matrix(conv_layer):
    """Gram matrix of a conv activation: channel-by-channel correlations,
    normalized by (number of spatial positions * number of channels)."""
    n_channels = conv_layer.get_shape().as_list()[3]
    # Flatten all spatial positions into rows: [h*w, channels].
    flat = tf.reshape(conv_layer, (-1, n_channels))
    n_positions = flat.get_shape().as_list()[0]
    gram = tf.matmul(flat, flat, transpose_a = True)
    return gram / (n_positions * n_channels)
def get_loss_style(gram_matrix_gen, gram_matrix_ref):
    """Mean squared difference between generated and reference Gram matrices."""
    diff = gram_matrix_gen - gram_matrix_ref
    return tf.reduce_mean(tf.square(diff))
def get_loss_content(gen_layer, ref_layer):
    """Mean squared difference between generated and reference activations."""
    diff = gen_layer - ref_layer
    return tf.reduce_mean(tf.square(diff))
# Feature maps for the style reference, content reference, and the image
# being generated (all share the same VGG16 parameters).
features_style = net(input_style, weights, biases)
features_content = net(input_content, weights, biases)
features_gen = net(input_gen, weights, biases)

# Style loss: summed Gram-matrix mismatch over the selected style layers.
loss_style = sum(get_loss_style(get_gram_matrix(features_gen[key]),
                                get_gram_matrix(features_style[key]))
                 for key in layers_style)

# Content loss: summed activation mismatch over the selected content layers.
loss_content = sum(get_loss_content(features_gen[key], features_content[key])
                   for key in layers_content)

# Relative weight of the style term in the total loss.
g = 1/(1e1)
loss_total = loss_content + g*loss_style
optm = tf.train.AdamOptimizer(LR).minimize(loss_total)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
n_iter = 1000
# Progress-print interval (iterations).
n_prt = 100
for itr in range(n_iter + 1):
    # One Adam step on the generated pixels; input_gen is a Variable and
    # needs no feed, only the two reference images do.
    sess.run(optm, feed_dict = {input_style: img_style[np.newaxis,:,:,:],
                                input_content: img_content[np.newaxis,:,:,:]})
    if itr%n_prt == 0:
        # Each loss depends on only one placeholder, so each is fed separately.
        ls = sess.run(loss_style, feed_dict = {input_style: img_style[np.newaxis,:,:,:]})
        lc = sess.run(loss_content, feed_dict = {input_content: img_content[np.newaxis,:,:,:]})
        print('Iteration: {}'.format(itr))
        print('Style loss: {}'.format(g*ls))
        print('Content loss: {}\n'.format(lc))
# Read back the optimized pixels and render as an 8-bit RGB image.
image = sess.run(input_gen)
image = np.uint8(np.clip(np.round(image), 0, 255)).squeeze()
plt.figure(figsize = (10,8))
plt.imshow(image)
plt.axis('off')
plt.show()
def get_loss_TV(conv_layer):
    """Total-variation regularizer: mean absolute difference between
    horizontally and vertically adjacent pixels (encourages smoothness)."""
    diff_w = conv_layer[:,:,1:,:] - conv_layer[:,:,:-1,:]
    diff_h = conv_layer[:,1:,:,:] - conv_layer[:,:-1,:,:]
    return tf.reduce_mean(tf.abs(diff_w)) + tf.reduce_mean(tf.abs(diff_h))
# Second phase: add a total-variation term computed on the generated pixels.
loss_TV = get_loss_TV(input_gen)
# NOTE(review): unlike the first phase, loss_style enters the total
# unweighted here (phase 1 used g*loss_style), yet the printout below still
# reports g*ls as "Style loss" -- confirm the intended style weighting.
loss_total = loss_content + loss_style + 100*loss_TV
optm = tf.train.AdamOptimizer(LR).minimize(loss_total)
# Fresh session: re-initializes input_gen to random noise for this run.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
n_iter = 500
n_prt = 100
for itr in range(n_iter + 1):
    sess.run(optm, feed_dict = {input_style : img_style[np.newaxis,:,:,:],
                                input_content : img_content[np.newaxis,:,:,:]})
    if itr%n_prt == 0:
        # Each loss depends on only one placeholder; loss_TV needs no feed.
        ls = sess.run(loss_style, feed_dict = {input_style : img_style[np.newaxis,:,:,:]})
        lc = sess.run(loss_content, feed_dict = {input_content : img_content[np.newaxis,:,:,:]})
        ltv = sess.run(loss_TV)
        print('Iteration: {}'.format(itr))
        print('Style loss: {}'.format(g*ls))
        print('Content loss: {}'.format(lc))
        print('TV loss: {}\n'.format(ltv))
# Read back the optimized pixels and render as an 8-bit RGB image.
image = sess.run(input_gen)
image = np.uint8(np.clip(np.round(image), 0, 255)).squeeze()
plt.figure(figsize = (10,8))
plt.imshow(image)
plt.axis('off')
plt.show()
# Notebook-only cell magic (loads a table-of-contents script in Jupyter);
# commented out because it is not valid Python outside a notebook.
# %%javascript
# $.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')